/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/

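/* !
 * \file where_c310_impl.h
 * \brief Where API implementation: element-wise selection between two sources (each a tensor or a scalar)
 *        driven by a bool condition tensor.
 */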
#ifndef IMPL_MATH_WHERE_WHERE_C310_IMPL_H
#define IMPL_MATH_WHERE_WHERE_C310_IMPL_H
#include "kernel_tensor.h"
#include "../../common/check.h"
#include "../../api_check/kernel_api_check.h"

namespace AscendC {
namespace WhereInternal {
/* !
 * \brief Maps an element size in bytes to an unsigned integer carrier type, exposed as the member alias T.
 *        The primary template yields uint8_t; it is used for every size that has no dedicated
 *        specialization below (including the default argument, sizeof(uint8_t)).
 */
template <uint32_t byteWidth = sizeof(uint8_t)>
struct ExtractDataTypeBySize {
    using T = uint8_t;
};

// Elements as wide as uint16_t are carried as uint16_t.
template <>
struct ExtractDataTypeBySize<sizeof(uint16_t)> {
    using T = uint16_t;
};

// Elements as wide as uint32_t are carried as uint32_t.
template <>
struct ExtractDataTypeBySize<sizeof(uint32_t)> {
    using T = uint32_t;
};
} // namespace WhereInternal
/* !
 * \brief Vector-function kernel of Where: for each iteration it builds a select mask from the condition
 *        bytes, combines src0 and src1 with MicroAPI::Select under that mask and stores the result to dstUb.
 *
 * \tparam src0Val  true: src0 is the scalar argument, broadcast once before the loop and src0Ub is not read;
 *                  false: src0 is loaded from src0Ub in every iteration and the scalar is ignored.
 * \tparam src1Val  same meaning as src0Val, for src1 / src1Ub.
 * \tparam T        element type of dst/src0/src1 as seen by the registers.
 * \tparam V        element type of the condition buffer; it is read through a uint8_t pointer.
 * \tparam regTrait register trait of the data registers; its REG_NUM scales the per-iteration element count.
 *
 * \param dstUb        destination buffer.
 * \param src0Ub       first source buffer (only dereferenced when src0Val is false).
 * \param src1Ub       second source buffer (only dereferenced when src1Val is false).
 * \param src0         first source scalar (only used when src0Val is true).
 * \param src1         second source scalar (only used when src1Val is true).
 * \param conditionUb  condition buffer; a byte different from 0 sets the corresponding select-mask bit.
 * \param count        number of elements to process; passed to UpdateMask in every iteration.
 * \param repeatTime   number of loop iterations.
 *
 * NOTE(review): which operand MicroAPI::Select takes where the mask bit is set is not visible here —
 * presumably src0 for set bits and src1 otherwise; confirm against the MicroAPI documentation.
 */
template <bool src0Val, bool src1Val, typename T, typename V, const MicroAPI::RegTrait& regTrait = MicroAPI::RegTraitNumOne>
__simd_vf__ inline void WhereCompute(__local_mem__ T* dstUb, __local_mem__ T* src0Ub,
    __local_mem__ T* src1Ub, const T src0, const T src1, __local_mem__ V* conditionUb,
    uint32_t count, const uint16_t repeatTime)
{
    // Number of elements handled per loop iteration; also the stride applied to every buffer offset below.
    constexpr uint32_t repeatElm = regTrait.REG_NUM * GetVecLen() / sizeof(T);
    MicroAPI::RegTensor<T, regTrait> src0Reg, src1Reg, dstReg;
    MicroAPI::RegTensor<uint8_t> selReg;
    MicroAPI::MaskReg maskReg, selMask;
    // Mask used for the condition comparison (created for uint8_t elements).
    MicroAPI::MaskReg maskFull = MicroAPI::CreateMask<uint8_t>();

    // Scalar operands are loop-invariant: broadcast them into their registers once, up front.
    if constexpr (src0Val) {
        MicroAPI::Duplicate(src0Reg, src0);
    }
    if constexpr (src1Val) {
        MicroAPI::Duplicate(src1Reg, src1);
    }
    for (uint16_t i = 0; i < repeatTime; ++i) {
        // Store mask for this iteration; presumably covers the tail and consumes `count` — TODO confirm
        // against MicroAPI::UpdateMask.
        maskReg = MicroAPI::UpdateMask<T, regTrait>(count);
        // Condition is read as raw bytes, starting at the same element offset as the data buffers.
        MicroAPI::DataCopy(selReg, (__local_mem__ uint8_t*)conditionUb + i * repeatElm);
        // selMask bit is set where the condition byte is not 0.
        MicroAPI::CompareScalar<uint8_t, CMPMODE::NE>(selMask, selReg, static_cast<uint8_t>(0), maskFull);
        // The mask was produced for 1-byte elements; it is unpacked once for 2-byte T and twice for
        // 4-/8-byte T (presumably each unpack widens the mask to the next element width — TODO confirm).
        if constexpr (sizeof(T) == 2) {
            MicroAPI::MaskUnPack(selMask, selMask);
        } else if constexpr (sizeof(T) == 4 || sizeof(T) == 8) {
            MicroAPI::MaskUnPack(selMask, selMask);
            MicroAPI::MaskUnPack(selMask, selMask);
        }

        // Tensor operands are reloaded every iteration from the current offset.
        if constexpr (!src0Val) {
            MicroAPI::DataCopy(src0Reg, src0Ub + i * repeatElm);
        }
        if constexpr (!src1Val) {
            MicroAPI::DataCopy(src1Reg, src1Ub + i * repeatElm);
        }
        MicroAPI::Select(dstReg, src0Reg, src1Reg, selMask);
        // Store is gated by maskReg (derived from count), not by the select mask.
        MicroAPI::DataCopy(dstUb + i * repeatElm, dstReg, maskReg);
    }
}

/* !
 * \brief Where entry point: for the first `count` elements, fills dst with values selected from src0 or
 *        src1 according to the bool condition tensor, by dispatching to WhereCompute.
 *
 * Each of src0 and src1 may independently be a LocalTensor or a scalar; the four combinations are resolved
 * at compile time. For element sizes other than 8 bytes the data is handled as the unsigned integer type
 * returned by WhereInternal::ExtractDataTypeBySize<sizeof(T)>; 8-byte types are handled as uint64_t with
 * MicroAPI::RegTraitNumTwo.
 *
 * \tparam T element type of dst (see the static_assert below for the supported list).
 * \tparam U type of src0: LocalTensor whose PrimType is T, or T itself for a scalar.
 * \tparam S type of src1: LocalTensor whose PrimType is T, or T itself for a scalar.
 * \tparam V element type of condition; only bool is accepted.
 *
 * \param dst        destination tensor.
 * \param src0       first source (tensor or scalar).
 * \param src1       second source (tensor or scalar).
 * \param condition  condition tensor.
 * \param count      number of elements to process.
 */
template <typename T, typename U, typename S, typename V>
__aicore__ inline void WhereImpl(const LocalTensor<T>& dst, const U& src0, const S& src1,
    const LocalTensor<V>& condition, const uint32_t count)
{
    static_assert(SupportType<T, bool, int8_t, uint8_t, int16_t, uint16_t, half, bfloat16_t, int32_t, uint32_t, float, int64_t, uint64_t>(),
    "Where only supports bool/int8_t/uint8_t/int16_t/uint16_t/half/bfloat16_t/int32_t/uint32_t/float/int64_t/uint64_t data type on current device");
    static_assert(SupportType<V, bool>(), "Where's argument of condition only supports bool data type on current device");

    CHECK_FUNC_HIGHLEVEL_API(Where, (T, U, S, V), (dst, src0, src1, condition, count));
    using WhereType = typename WhereInternal::ExtractDataTypeBySize<sizeof(T)>::T;

    __local_mem__ V *conditionUb = (__local_mem__ V *)condition.GetPhyAddr();

    // WhereCompute advances by REG_NUM * GetVecLen() / sizeof(T) elements per iteration. 8-byte types are
    // dispatched with RegTraitNumTwo, so the iteration count has to be derived from that same stride;
    // dividing by the single-register element count would run twice as many iterations as needed, and the
    // surplus iterations would load from addresses past the processed range.
    constexpr uint32_t regNum =
        (sizeof(T) == 8) ? MicroAPI::RegTraitNumTwo.REG_NUM : MicroAPI::RegTraitNumOne.REG_NUM;
    constexpr uint32_t repeatElm = regNum * GetVecLen() / sizeof(T);
    const uint16_t repeatTime = static_cast<uint16_t>(CeilDivision(count, repeatElm));

    if constexpr (TypeUtils::IsLocalTensorType<U, S>()) {
        // src0: tensor, src1: tensor. The scalar arguments are unused placeholders (0).
        static_assert(Std::is_same<U, S>::value);
        static_assert(Std::is_same<T, typename U::PrimType>::value);
        if constexpr (sizeof(T) != 8) {
            WhereCompute<false, false, WhereType, V>((__local_mem__ WhereType*)dst.GetPhyAddr(),
            (__local_mem__ WhereType*)src0.GetPhyAddr(), (__local_mem__ WhereType*)src1.GetPhyAddr(), 0, 0, conditionUb, count, repeatTime);
        } else {
            WhereCompute<false, false, uint64_t, V, MicroAPI::RegTraitNumTwo>((__local_mem__ uint64_t*)dst.GetPhyAddr(),
            (__local_mem__ uint64_t*)src0.GetPhyAddr(), (__local_mem__ uint64_t*)src1.GetPhyAddr(), 0, 0, conditionUb, count, repeatTime);
        }
    } else if constexpr (TypeUtils::IsLocalTensorType<U>() && TypeUtils::IsInnerDefaultType<S>()) {
        // src0: tensor, src1: scalar. The scalar is passed reinterpreted as the unsigned carrier type.
        static_assert(Std::is_same<T, S>::value);
        static_assert(Std::is_same<T, typename U::PrimType>::value);
        if constexpr (sizeof(T) != 8) {
            WhereCompute<false, true, WhereType, V>((__local_mem__ WhereType*)dst.GetPhyAddr(),
            (__local_mem__ WhereType*)src0.GetPhyAddr(), nullptr, 0, (WhereType&)src1, conditionUb, count, repeatTime);
        } else {
            WhereCompute<false, true, uint64_t, V, MicroAPI::RegTraitNumTwo>((__local_mem__ uint64_t*)dst.GetPhyAddr(),
            (__local_mem__ uint64_t*)src0.GetPhyAddr(), nullptr, 0, (uint64_t&)src1, conditionUb, count, repeatTime);
        }
    } else if constexpr (TypeUtils::IsLocalTensorType<S>() && TypeUtils::IsInnerDefaultType<U>()) {
        // src0: scalar, src1: tensor.
        static_assert(Std::is_same<T, U>::value);
        static_assert(Std::is_same<T, typename S::PrimType>::value);
        if constexpr (sizeof(T) != 8) {
            WhereCompute<true, false, WhereType, V>((__local_mem__ WhereType*)dst.GetPhyAddr(),
            nullptr, (__local_mem__ WhereType*)src1.GetPhyAddr(), (WhereType&)src0, 0, conditionUb, count, repeatTime);
        } else {
            WhereCompute<true, false, uint64_t, V, MicroAPI::RegTraitNumTwo>((__local_mem__ uint64_t*)dst.GetPhyAddr(),
            nullptr, (__local_mem__ uint64_t*)src1.GetPhyAddr(), (uint64_t&)src0, 0, conditionUb, count, repeatTime);
        }
    } else {
        // src0: scalar, src1: scalar. No source buffer is read.
        static_assert(Std::is_same<T, U>::value);
        static_assert(Std::is_same<T, S>::value);
        if constexpr (sizeof(T) != 8) {
            WhereCompute<true, true, WhereType, V>((__local_mem__ WhereType*)dst.GetPhyAddr(),
            nullptr, nullptr, (WhereType&)src0, (WhereType&)src1, conditionUb, count, repeatTime);
        } else {
            WhereCompute<true, true, uint64_t, V, MicroAPI::RegTraitNumTwo>((__local_mem__ uint64_t*)dst.GetPhyAddr(),
            nullptr, nullptr, (uint64_t&)src0, (uint64_t&)src1, conditionUb, count, repeatTime);
        }
    }
}
} // namespace AscendC
#endif // IMPL_MATH_WHERE_WHERE_C310_IMPL_H
